{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate images with your fine-tuned Stable Diffusion model\n",
    "\n",
    "You should use this notebook to interactively generate images, after you've already fine-tuned a stable diffusion model and have a model checkpoint available to load. See the README for instructions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# TODO: Change this to the path of your fine-tuned model checkpoint!\n",
    "# This is the $TUNED_MODEL_DIR variable defined in the run script.\n",
    "TUNED_MODEL_PATH = \"/tmp/model-tuned\"\n",
    "# TODO: Set the following variables if you fine-tuned with LoRA.\n",
    "ORIG_MODEL_PATH = \"/tmp/model-orig/models--CompVis--stable-diffusion-v1-4/snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce/\"\n",
    "LORA_WEIGHTS_DIR = \"/tmp/model-tuned\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, load the model checkpoint as a HuggingFace 🤗 pipeline.\n",
    "Load the model onto a GPU and define a function to generate images from a text prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from os import environ\n",
    "\n",
    "import torch\n",
    "from diffusers import DiffusionPipeline\n",
    "\n",
    "from dreambooth.generate_utils import load_lora_weights, get_pipeline\n",
    "\n",
    "pipeline = None\n",
    "\n",
    "def on_full_ft():\n",
    "    global pipeline\n",
    "    pipeline = get_pipeline(TUNED_MODEL_PATH)\n",
    "    pipeline.to(\"cuda\")\n",
    "    \n",
    "def on_lora_ft():\n",
    "    assert ORIG_MODEL_PATH\n",
    "    assert LORA_WEIGHTS_DIR\n",
    "    global pipeline\n",
    "    pipeline = get_pipeline(ORIG_MODEL_PATH, LORA_WEIGHTS_DIR)\n",
    "    pipeline.to(\"cuda\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def generate(\n",
    "    pipeline: DiffusionPipeline,\n",
    "    prompt: str,\n",
    "    img_size: int = 512,\n",
    "    num_samples: int = 1,\n",
    ") -> list:\n",
    "    return pipeline([prompt] * num_samples, height=img_size, width=img_size).images\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Try out your model!\n",
    "\n",
    "Now, play with your fine-tuned diffusion model through this simple GUI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output\n",
    "\n",
    "# TODO: When giving prompts, make sure to include your subject's unique identifier,\n",
    "# as well as its class name.\n",
    "# For example, if your subject's unique identifier is \"unqtkn\" and is a dog,\n",
    "# you can give the prompt \"photo of a unqtkn dog on the beach\".\n",
    "\n",
    "# IPython GUI Layouts\n",
    "\n",
    "output = widgets.Output()\n",
    "toggle_buttons = widgets.ToggleButtons(\n",
    "    options=[\"Full fine-tuning\",\"LoRA fine-tuning\"],\n",
    "    disabled=False,\n",
    "    button_style='', # 'success', 'info', 'warning', 'danger' or ''\n",
    "    value=None,\n",
    "    # layout=widgets.Layout(width='100px')\n",
    ")\n",
    "\n",
    "def toggle_callback(change):\n",
    "    with output:\n",
    "        clear_output()\n",
    "        if change[\"new\"] == \"Full fine-tuning\":\n",
    "            on_full_ft()\n",
    "        else:\n",
    "            on_lora_ft()\n",
    "        \n",
    "toggle_buttons.observe(toggle_callback, names=\"value\")\n",
    "    \n",
    "input_text = widgets.Text(\n",
    "    value=\"photo of a unqtkn dog on the beach\",\n",
    "    placeholder=\"\",\n",
    "    description=\"Prompt:\",\n",
    "    disabled=False,\n",
    "    layout=widgets.Layout(width=\"500px\"),\n",
    ")\n",
    "\n",
    "button = widgets.Button(description=\"Generate!\")\n",
    "\n",
    "# Define button click event\n",
    "def on_button_clicked(b):\n",
    "    with output:\n",
    "        clear_output()\n",
    "        print(\"Generating images...\")\n",
    "        print(\n",
    "            \"(The output image may be completely black if it's filtered by \"\n",
    "            \"HuggingFace diffusers safety checkers.)\"\n",
    "        )\n",
    "        start_time = time.time()\n",
    "        images = generate(pipeline=pipeline, prompt=input_text.value, num_samples=2)\n",
    "        display(*images)\n",
    "        finish_time = time.time()\n",
    "        print(f\"Completed in {finish_time - start_time} seconds.\")\n",
    "\n",
    "button.on_click(on_button_clicked)\n",
    "\n",
    "# Display the widgets\n",
    "display(toggle_buttons, widgets.HBox([input_text, button]), output)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# release memory properly\n",
    "del pipeline \n",
    "torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
